// Copyright Epic Games, Inc. All Rights Reserved.

#include "winsocktransport.h"

#if ZEN_WITH_PLUGINS

#	include <zencore/logging.h>
#	include <zencore/scopeguard.h>
#	include <zencore/workthreadpool.h>

ZEN_THIRD_PARTY_INCLUDES_START
#	include <asio.hpp>
ZEN_THIRD_PARTY_INCLUDES_END

#	if ZEN_PLATFORM_WINDOWS
#		include <zencore/windows.h>
ZEN_THIRD_PARTY_INCLUDES_START
#		include <mstcpip.h>
ZEN_THIRD_PARTY_INCLUDES_END
#	endif

#	include <fmt/format.h>

#	include <memory>
#	include <thread>

namespace zen {

struct AsioTransportAcceptor;

// Transport plugin serving TCP connections through standalone asio.
// Owns the io_service, the listening acceptor and the pool of worker threads
// that drive the asio event loop. Lifetime is ref-counted (AddRef/Release).
class AsioTransportPlugin : public TransportPlugin, RefCounted
{
public:
	AsioTransportPlugin();
	~AsioTransportPlugin();

	// TransportPlugin interface
	virtual uint32_t	AddRef() const override;
	virtual uint32_t	Release() const override;
	virtual void		Configure(const char* OptionTag, const char* OptionValue) override;
	virtual void		Initialize(TransportServer* ServerInterface) override;
	virtual void		Shutdown() override;
	virtual const char* GetDebugName() override { return nullptr; }
	virtual bool		IsAvailable() override;

private:
	bool	 m_IsOk		   = true;	// NOTE(review): never read or cleared in this file -- confirm use elsewhere
	uint16_t m_BasePort	   = 8558;	// preferred listen port; overridden via Configure("port", ...)
	int		 m_ThreadCount = 0;		// worker thread count; set in ctor, overridden via Configure("threads", ...)

	asio::io_service					   m_IoService;
	asio::io_service::work				   m_Work{m_IoService};	 // keeps io_service::run() from returning while idle
	std::unique_ptr<AsioTransportAcceptor> m_Acceptor;
	std::vector<std::thread>			   m_ThreadPool;  // threads running m_IoService.run(); joined in Shutdown()
};

// One accepted TCP connection. Lifetime is managed through std::shared_ptr:
// the async read handler captures a shared reference (via AsSharedPtr), which
// keeps the object alive for as long as the read loop keeps re-arming itself.
struct AsioTransportConnection : public TransportConnection, std::enable_shared_from_this<AsioTransportConnection>
{
	// Takes ownership of an already-accepted, already-configured socket.
	AsioTransportConnection(std::unique_ptr<asio::ip::tcp::socket>&& Socket);
	~AsioTransportConnection();

	// Attaches the server-side handler and starts the async read loop.
	void Initialize(TransportServerConnection* ConnectionHandler);

	std::shared_ptr<AsioTransportConnection> AsSharedPtr() { return shared_from_this(); }

	// TransportConnectionInterface

	virtual int64_t		WriteBytes(const void* Buffer, size_t DataSize) override;
	virtual void		Shutdown(bool Receive, bool Transmit) override;
	virtual void		CloseConnection() override;
	virtual const char* GetDebugName() override { return nullptr; }

private:
	void EnqueueRead();
	void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount);
	void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount);

	Ref<TransportServerConnection>		   m_ConnectionHandler;
	asio::streambuf						   m_RequestBuffer;	 // accumulates inbound bytes between read completions
	std::unique_ptr<asio::ip::tcp::socket> m_Socket;
	uint32_t							   m_ConnectionId = 0;	// NOTE(review): never assigned in this file, so log lines always show 0 -- confirm
	std::atomic_flag					   m_IsTerminated{};  // set once by CloseConnection(); read loop stops when set
};

//////////////////////////////////////////////////////////////////////////

// Listens on a TCP port and hands each accepted socket to a freshly created
// AsioTransportConnection. The accept loop re-arms itself after every
// completion until RequestStop() is observed.
struct AsioTransportAcceptor
{
	AsioTransportAcceptor(TransportServer* ServerInterface, asio::io_service& IoService, uint16_t BasePort);
	~AsioTransportAcceptor();

	// Puts the socket into the listening state and queues the first accept.
	void Start();
	// Flags the accept loop to stop; the acceptor socket is closed by the
	// next accept completion handler, not immediately here.
	void RequestStop();

	// Port actually bound (may differ from BasePort; see constructor retry logic).
	inline int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); }

private:
	TransportServer*		m_ServerInterface = nullptr;  // non-owning; creates per-connection handlers
	asio::io_service&		m_IoService;
	asio::ip::tcp::acceptor m_Acceptor;
	std::atomic<bool>		m_IsStopped{false};	 // set by RequestStop(), read by the accept handler

	void EnqueueAccept();
};

//////////////////////////////////////////////////////////////////////////

// Opens a dual-stack (IPv4 + IPv6) TCP acceptor and binds it at or near
// BasePort. Binding strategy: try BasePort; if in use, probe BasePort+100,
// +200, ... up to +900; if access is denied, fall back to an ephemeral port
// (0). A bind failure is logged but not treated as fatal here.
AsioTransportAcceptor::AsioTransportAcceptor(TransportServer* ServerInterface, asio::io_service& IoService, uint16_t BasePort)
: m_ServerInterface(ServerInterface)
, m_IoService(IoService)
, m_Acceptor(m_IoService, asio::ip::tcp::v6())
{
	// Disable v6-only so the single v6 listener also accepts IPv4 clients.
	m_Acceptor.set_option(asio::ip::v6_only(false));

#	if ZEN_PLATFORM_WINDOWS
	// Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms
	typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> exclusive_address;
	m_Acceptor.set_option(exclusive_address(true));
#	else
	m_Acceptor.set_option(asio::socket_base::reuse_address(false));
#	endif

	// Options applied to the listening socket; accepted sockets also have
	// these set explicitly again in the accept completion handler.
	m_Acceptor.set_option(asio::ip::tcp::no_delay(true));
	m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024));
	m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024));

	uint16_t EffectivePort = BasePort;

	asio::error_code BindErrorCode;
	m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode);
	// Sharing violation implies the port is being used by another process
	for (uint16_t PortOffset = 1; (BindErrorCode == asio::error::address_in_use) && (PortOffset < 10); ++PortOffset)
	{
		EffectivePort = BasePort + (PortOffset * 100);
		m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode);
	}
	if (BindErrorCode == asio::error::access_denied)
	{
		// Port is reserved/forbidden; let the OS pick any free port instead.
		EffectivePort = 0;
		m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode);
	}
	if (BindErrorCode)
	{
		ZEN_ERROR("Unable open asio service, error '{}'", BindErrorCode.message());
	}

#	if ZEN_PLATFORM_WINDOWS
	// On Windows, loopback connections can take advantage of a faster code path optionally with this flag.
	// This must be used by both the client and server side, and is only effective in the absence of
	// Windows Filtering Platform (WFP) callouts which can be installed by security software.
	// https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path
	// NOTE(review): the WSAIoctl return value is ignored -- this appears to be
	// deliberate best-effort (the fast path is optional), confirm.
	SOCKET NativeSocket				   = m_Acceptor.native_handle();
	int	   LoopbackOptionValue		   = 1;
	DWORD  OptionNumberOfBytesReturned = 0;
	WSAIoctl(NativeSocket,
			 SIO_LOOPBACK_FAST_PATH,
			 &LoopbackOptionValue,
			 sizeof(LoopbackOptionValue),
			 NULL,
			 0,
			 &OptionNumberOfBytesReturned,
			 0,
			 0);
#	endif

	// NOTE(review): when the ephemeral-port fallback (port 0) was taken this
	// logs "port: 0" rather than the OS-assigned port (GetAcceptPort() would
	// report the real one), and it logs even if the final bind failed.
	ZEN_INFO("started asio transport at port: {}", EffectivePort);
}

// No explicit teardown needed; the acceptor socket closes in its own destructor.
AsioTransportAcceptor::~AsioTransportAcceptor() = default;

// Transition the bound socket into the listening state, then queue the first
// asynchronous accept so the loop begins handling incoming connections.
void
AsioTransportAcceptor::Start()
{
	m_Acceptor.listen();
	EnqueueAccept();
}

// Signal the accept loop to wind down; the pending accept handler closes the
// acceptor socket once it observes the flag.
void
AsioTransportAcceptor::RequestStop()
{
	m_IsStopped.store(true);
}

// Arms one asynchronous accept. The completion handler either spawns an
// AsioTransportConnection for the new socket or logs the error, then re-arms
// itself unless RequestStop() has been called, in which case it closes the
// acceptor and the loop ends.
void
AsioTransportAcceptor::EnqueueAccept()
{
	// The socket must stay alive across the async call: it is owned by the
	// handler closure (moved into the capture) and, on success, moved on into
	// the connection object.
	auto				   SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService);
	asio::ip::tcp::socket& SocketRef = *SocketPtr.get();

	m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable {
		if (Ec)
		{
			// NOTE(review): local_endpoint() can itself fail once the acceptor
			// has been closed -- confirm this path is safe during shutdown.
			ZEN_WARN("asio async_accept error ({}:{}): {}",
					 m_Acceptor.local_endpoint().address().to_string(),
					 m_Acceptor.local_endpoint().port(),
					 Ec.message());
		}
		else
		{
			// New connection established, pass socket ownership into connection object
			// and initiate request handling loop. The connection lifetime is
			// managed by the async read/write loop by passing the shared
			// reference to the callbacks.

			Socket->set_option(asio::ip::tcp::no_delay(true));
			Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024));
			Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024));

			auto Conn = std::make_shared<AsioTransportConnection>(std::move(Socket));
			Conn->Initialize(m_ServerInterface->CreateConnectionHandler(Conn.get()));
		}

		// Keep the accept loop running until a stop has been requested.
		if (!m_IsStopped.load())
		{
			EnqueueAccept();
		}
		else
		{
			std::error_code CloseEc;
			m_Acceptor.close(CloseEc);

			if (CloseEc)
			{
				ZEN_WARN("acceptor close ERROR, reason '{}'", CloseEc.message());
			}
		}
	});
}

//////////////////////////////////////////////////////////////////////////

// Take ownership of an already-accepted socket; the read loop is started
// later via Initialize().
AsioTransportConnection::AsioTransportConnection(std::unique_ptr<asio::ip::tcp::socket>&& Socket)
: m_Socket(std::move(Socket))
{
}

// Members tear themselves down; the socket is closed by its own destructor.
AsioTransportConnection::~AsioTransportConnection() = default;

// Attach the server-side handler for this connection, then kick off the
// asynchronous read loop.
void
AsioTransportConnection::Initialize(TransportServerConnection* ConnectionHandler)
{
	m_ConnectionHandler = ConnectionHandler;
	EnqueueRead();
}

int64_t
AsioTransportConnection::WriteBytes(const void* Buffer, size_t DataSize)
{
	size_t WrittenBytes = asio::write(*m_Socket.get(), asio::const_buffer(Buffer, DataSize), asio::transfer_exactly(DataSize));

	return WrittenBytes;
}

void
AsioTransportConnection::Shutdown(bool Receive, bool Transmit)
{
	std::error_code Ec;
	if (Receive)
	{
		if (Transmit)
		{
			m_Socket->shutdown(asio::socket_base::shutdown_both, Ec);
		}
		else
		{
			m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec);
		}
	}
	else if (Transmit)
	{
		m_Socket->shutdown(asio::socket_base::shutdown_send, Ec);
	}
}

// Idempotently tears down the connection: shuts down both directions and
// closes the socket. Safe to race from multiple threads -- exactly one caller
// performs the actual close.
void
AsioTransportConnection::CloseConnection()
{
	// test_and_set() atomically returns the previous value, so only the first
	// caller observes 'false' and proceeds. The original code had a separate
	// test() fast path before this, which was redundant -- test_and_set()
	// already covers it in a single atomic operation.
	if (m_IsTerminated.test_and_set())
	{
		return;
	}

	Shutdown(/*Receive*/ true, /*Transmit*/ true);

	std::error_code Ec;
	m_Socket->close(Ec);
}

// Queue the next asynchronous read, unless the connection has already been
// terminated. The shared_ptr captured by the completion handler keeps this
// object alive for as long as the read loop keeps re-arming itself.
void
AsioTransportConnection::EnqueueRead()
{
	if (m_IsTerminated.test())
	{
		return;
	}

	// Reserve input space in the stream buffer ahead of the read.
	m_RequestBuffer.prepare(64 * 1024);

	auto Self = AsSharedPtr();
	asio::async_read(*m_Socket,
					 m_RequestBuffer,
					 asio::transfer_at_least(1),
					 [Self = std::move(Self)](const asio::error_code& ErrorCode, std::size_t ByteCount) {
						 Self->OnDataReceived(ErrorCode, ByteCount);
					 });
}

// Completion handler for EnqueueRead. Forwards buffered bytes to the
// server-side connection handler and re-arms the read; on error the
// connection is shut down in both directions and the loop ends (dropping the
// last shared reference held by the handler).
void
AsioTransportConnection::OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount)
{
	ZEN_UNUSED(ByteCount);

	if (Ec)
	{
		// Suppress the warning when the error was triggered by our own
		// CloseConnection() (flag already set).
		if (!m_IsTerminated.test())
		{
			ZEN_WARN("on data received ERROR, connection: {}, reason '{}'", m_ConnectionId, Ec.message());
		}

		const bool Receive	= true;
		const bool Transmit = true;
		return Shutdown(Receive, Transmit);
	}

	// Drain the buffer: each pass hands the entire readable region to the
	// handler and then consumes it, so this loop performs at most one
	// iteration per completion in practice.
	while (m_RequestBuffer.size())
	{
		const asio::const_buffer& InputBuffer = m_RequestBuffer.data();
		m_ConnectionHandler->OnBytesRead(InputBuffer.data(), InputBuffer.size());
		m_RequestBuffer.consume(InputBuffer.size());
	}

	EnqueueRead();
}

// Completion handler for asynchronous response writes. A failed write shuts
// the connection down in both directions; success needs no further action.
void
AsioTransportConnection::OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount)
{
	ZEN_UNUSED(ByteCount);

	if (!Ec)
	{
		return;
	}

	ZEN_WARN("on data sent ERROR, connection: {}, reason '{}'", m_ConnectionId, Ec.message());
	Shutdown(/*Receive*/ true, /*Transmit*/ true);
}

//////////////////////////////////////////////////////////////////////////

// Default the worker thread count to the machine's hardware concurrency,
// but never fewer than eight; Configure("threads", ...) may override this.
AsioTransportPlugin::AsioTransportPlugin()
{
	m_ThreadCount = Max(std::thread::hardware_concurrency(), 8u);
}

// Members tear themselves down; Shutdown() is expected to have joined the
// worker threads before destruction.
AsioTransportPlugin::~AsioTransportPlugin() = default;

// Forward reference counting to the RefCounted base implementation.
uint32_t
AsioTransportPlugin::AddRef() const
{
	const uint32_t NewCount = RefCounted::AddRef();
	return NewCount;
}

// Forward reference counting to the RefCounted base implementation.
uint32_t
AsioTransportPlugin::Release() const
{
	const uint32_t NewCount = RefCounted::Release();
	return NewCount;
}

void
AsioTransportPlugin::Configure(const char* OptionTag, const char* OptionValue)
{
	using namespace std::literals;

	if (OptionTag == "port"sv)
	{
		if (auto PortNum = ParseInt<uint16_t>(OptionValue))
		{
			m_BasePort = *PortNum;
		}
	}
	else if (OptionTag == "threads"sv)
	{
		if (auto ThreadCount = ParseInt<int>(OptionValue))
		{
			m_ThreadCount = *ThreadCount;
		}
	}
	else
	{
		// Unknown configuration option
	}
}

// Creates and starts the acceptor, then spins up the worker thread pool that
// drives the asio event loop. Must be called after Configure().
//
// ServerInterface: non-null factory for per-connection handlers; must
// outlive this plugin.
void
AsioTransportPlugin::Initialize(TransportServer* ServerInterface)
{
	ZEN_ASSERT(m_ThreadCount > 0);
	ZEN_ASSERT(ServerInterface);

	ZEN_INFO("starting asio http with {} service threads", m_ThreadCount);

	// std::make_unique instead of reset(new ...) -- same behavior, avoids the
	// naked new.
	m_Acceptor = std::make_unique<AsioTransportAcceptor>(ServerInterface, m_IoService, m_BasePort);
	m_Acceptor->Start();

	// This should consist of a set of minimum threads and grow on demand to
	// meet concurrency needs? Right now we end up allocating a large number
	// of threads even if we never end up using all of them, which seems
	// wasteful. It's also not clear how the demand for concurrency should
	// be balanced with the engine side - ideally we'd have some kind of
	// global scheduling to prevent one side from starving the other side
	// and thus preventing progress. Or at the very least, thread priorities
	// should be considered

	for (int i = 0; i < m_ThreadCount; ++i)
	{
		m_ThreadPool.emplace_back([this, ThreadNumber = i + 1] {
			SetCurrentThreadName(fmt::format("asio_thr_{}", ThreadNumber));

			try
			{
				// Blocks, dispatching handlers, until the io_service stops.
				m_IoService.run();
			}
			catch (const std::exception& e)
			{
				// An escaped exception terminates this worker thread; log it
				// so the failure is at least visible.
				ZEN_ERROR("exception caught in asio event loop: {}", e.what());
			}
		});
	}

	ZEN_INFO("asio http transport started (port {})", m_Acceptor->GetAcceptPort());
}

// Stops the accept loop and the io_service, then blocks until every worker
// thread has drained its pending handlers and exited.
void
AsioTransportPlugin::Shutdown()
{
	m_Acceptor->RequestStop();
	m_IoService.stop();

	for (std::thread& Worker : m_ThreadPool)
	{
		Worker.join();
	}
}

// The asio transport has no runtime preconditions; it is always available.
bool
AsioTransportPlugin::IsAvailable()
{
	return true;
}

// Factory entry point for the asio transport plugin. Ownership of the
// returned instance is managed through the AddRef/Release protocol.
TransportPlugin*
CreateAsioTransportPlugin()
{
	// Raw new is intentional here: lifetime is controlled by ref counting,
	// not by a smart pointer.
	return new AsioTransportPlugin();
}

}  // namespace zen

#endif
